""" Import libraries """
from casadi import *
import numpy as np
import pandas as pd
import copy
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use('Qt5Agg')
import time

from model.model_helper import ModelHelper
from model.env_rl_helper import RLModelHelper
from utils.init_net_helper import ini_net, save_missing_arguments

"""
The code serves as the main file to simulate state trajectories of a process.
Extra functions can be added for a more advanced simulation
1. Solving a steady state optimization, before runs a simulation (for-loop)
2. Solving a control problem, within the for-loop, before applying uk onto the process

Input: 
    f(x,u,d) or f(x,z,u,d): a continuous ODE or DAE. The disturbance d can be combined with u yet constant.
    dt, tf, Nt: temporal parameters. Needs 2 out of 3. The last one is dependent on defined 2.
    metricCalc: to calculate metrics - the objective function of the optimization problem
    
Output: a steady state and state trajectories
"""
# todo: try to make it generic so we can use it in the future without too many modifications

""" Input """
# =====================================================================
# Simulation timing
# =====================================================================
# Plant simulation step (sampling time), [min]. It must equal the sampling
# time used when the data were generated.
# TODO: this value is also defined in model_helper -- keep the two in sync.
dt = 0.1
tf = 200 * dt             # Total simulation horizon, [min] (200 samples)
eps = 1e-30               # A very small number
Nt = round(tf / dt)       # Number of simulation steps (round() -> int)

# =====================================================================
# Initial values / guesses of the plant -- the link to the data
# =====================================================================
# Initial guess for the steady-state optimization; initial states when only
# a process simulation is run. Values come from the online implementation
# of ECO-RL, seed 0.
xs = np.array([0.52917698, 348.37278815])
us = np.array([300.51267918])

# =====================================================================
# Model dimensions
# =====================================================================
Nx = 2  # Differential states
Nu = 1  # Control inputs

# ---------------------------------------------------------------------
# Utilities for simulation that will be used many times. Define outputs identified by Carbon Clean
# ---------------------------------------------------------------------
def metricCalc(x, u):
    """Return the metric (optimization objective) for state ``x`` and input ``u``.

    The metric is simply the first differential state, ``x[0]``. ``u`` is
    accepted only to keep the ``(x, u)`` call signature and is not used.
    Any indexable ``x`` works (e.g. a NumPy array or a CasADi vector).
    """
    first_state = x[0]
    return first_state


""" ------------------------------------------------- Input ends ------------------------------------------------- """

# ---------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------
# Initialize plant
plant = ModelHelper([Nx], [Nu])

# Steady-state optimization: xs/us are the initial guess; the optimum is used
# as the initial state and as the (constant) input of the simulation below.
xguess = xs
uguess = us
xs_opt, us_opt = plant.ss_optimization(Nx, Nu, metricCalc, xguess, uguess)
# The type returned by ss_optimization is not visible here (possibly a CasADi
# DM column or a NumPy array). Convert to independent flat float arrays
# (np.array copies) so X[0] has the same (Nx,) shape as the integrator
# results appended in the loop, and np.array(X) below stacks to (Nt+1, Nx).
x0 = np.array(xs_opt, dtype=float).reshape(-1)
uk = np.array(us_opt, dtype=float).reshape(-1)

# Build integrator for plant simulation
# Several options depending on packages selected
# 1. scipy solve_ivp or odeint on continuous-time (ct) odes
# 2. casadi integrator on ct odes
# 3. mpctools DiscreteSimulator on ct odes
# 4. casadi Function or mpctools getCasadiFunc on discrete-time (dt) odes

# Declare symbolic variables for the system
t = SX.sym('t')  # Time
x = SX.sym('x', Nx)  # Differential states
u = SX.sym('u', Nu)  # Control inputs

ode = plant.getODE(x, u)
ode = {'x': x, 'p': vertcat(u), 'ode': ode}
# Each integrator call advances the state by one sampling interval dt.
# NOTE(review): newer CasADi releases deprecate the "tf" option in favour of
# integrator(name, solver, dae, t0, tf, opts); kept as-is for compatibility
# with the installed version -- confirm before upgrading CasADi.
opts = {"tf": dt, "abstol": 1e-5}
plant_sim = integrator('I', 'cvodes', ode, opts)

# Trajectory storage
tt = []    # Sample times; dt is given in [min], so i*dt/60 is in [h]
X = [x0]   # States: X[k] is the state at sample k (Nt+1 entries)
U = []     # Inputs: U[k] is held constant over step k (Nt entries)
Y = []     # Metric evaluated at (X[k], U[k]) (Nt entries)

# Run simulation
start = time.perf_counter()

for i in range(Nt):
    print(i)

    tt += [i*dt/60]  # Current sample time
    U += [uk]

    sol = plant_sim(x0=X[i], p=vertcat(U[i]))

    X += [sol["xf"].full()[:, 0]]
    Y += [metricCalc(X[i], U[i])]

# Final time stamp so tt lines up with X (Nt+1 entries). It is computed from
# Nt rather than the loop variable, so it is not a repeat of the last sample
# time and is still defined when Nt == 0.
tt += [Nt*dt/60]
totaltime = time.perf_counter() - start
print('Simulation complete. Total time: %5.3g s' % totaltime)

# Convert results to numpy arrays
tt = np.array(tt)
X = np.array(X)
U = np.array(U)
Y = np.array(Y)

# Plot
# States get separate axes: their magnitudes differ by orders of magnitude
# (see xs), so a shared axis would flatten the smaller one.
fig, axes = plt.subplots(Nx, 1, sharex=True, squeeze=False)
for k in range(Nx):
    axes[k, 0].plot(tt, X[:, k])
    axes[k, 0].set_ylabel(f'x[{k}]')
axes[-1, 0].set_xlabel('Time [h]')

# Inputs are held constant over each step, so draw them as a staircase.
plt.figure()
plt.step(tt[:-1], U, where='post')
plt.xlabel('Time [h]')
plt.ylabel('u')

# Show both figures together instead of blocking after the first one.
plt.show()

print('done')